#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
交互式股票预测程序
结合Akshare数据下载和Kronos模型进行股票预测
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import akshare as ak
import os
import sys
from datetime import datetime, timedelta
import warnings
import holidays
warnings.filterwarnings('ignore')

# 添加项目根目录到路径
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

from model import Kronos, KronosTokenizer, KronosPredictor

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

class InteractiveStockPredictor:
    """交互式股票预测器"""
    
    def __init__(self):
        """初始化预测器"""
        self.predictor = None
        self.model = None
        self.tokenizer = None
        self.device = "cuda:0" if self._check_cuda() else "cpu"
        print(f"使用设备: {self.device}")
        
        # 初始化中国节假日
        self.cn_holidays = holidays.China()
        
    def _check_cuda(self):
        """检查CUDA是否可用"""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False
    
    def is_trading_day(self, date):
        """判断是否为交易日（排除周末和节假日）"""
        # 排除周末
        if date.weekday() >= 5:  # 5=周六, 6=周日
            return False
        # 排除节假日
        if date in self.cn_holidays:
            return False
        return True
    
    def generate_trading_days(self, start_date, num_days):
        """生成指定数量的交易日"""
        trading_days = []
        current_date = start_date
        
        while len(trading_days) < num_days:
            if self.is_trading_day(current_date):
                trading_days.append(current_date)
            current_date += timedelta(days=1)
        
        return trading_days
    
    def load_models(self):
        """加载Kronos模型和分词器"""
        try:
            print("正在加载Kronos模型...")
            self.tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
            self.model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
            self.predictor = KronosPredictor(
                model=self.model, 
                tokenizer=self.tokenizer, 
                device=self.device, 
                max_context=512
            )
            print("✅ 模型加载成功!")
            return True
        except Exception as e:
            print(f"❌ 模型加载失败: {str(e)}")
            return False
    
    def get_stock_codes(self):
        """获取股票代码列表"""
        print("\n" + "="*60)
        print("股票代码输入方式")
        print("="*60)
        print("1. 手动输入股票代码")
        print("2. 从TXT文件读取股票代码列表")
        print()
        
        while True:
            choice = input("请选择输入方式 (1/2): ").strip()
            
            if choice == '1':
                return self._get_manual_codes()
            elif choice == '2':
                return self._get_codes_from_file()
            else:
                print("❌ 请输入 1 或 2")
    
    def _get_manual_codes(self):
        """手动输入股票代码"""
        print("\n手动输入股票代码")
        print("-" * 30)
        print("格式说明:")
        print("- 多个股票代码用逗号分隔")
        print("- 支持A股代码格式: 600030, 002261, 688326, 300364")
        print("- 示例: 600030,002261")
        print()
        
        while True:
            stock_input = input("请输入股票代码: ").strip()
            if not stock_input:
                print("❌ 请输入有效的股票代码")
                continue
            
            # 解析股票代码
            stock_codes = [code.strip() for code in stock_input.split(',')]
            stock_codes = [code for code in stock_codes if code]
            
            if not stock_codes:
                print("❌ 请输入有效的股票代码")
                continue
            
            # 验证股票代码格式
            valid_codes = []
            for code in stock_codes:
                if code.isdigit() and len(code) == 6:
                    valid_codes.append(code)
                else:
                    print(f"⚠️  股票代码 {code} 格式不正确，已跳过")
            
            if not valid_codes:
                print("❌ 没有有效的股票代码")
                continue
            
            return valid_codes
    
    def _get_codes_from_file(self):
        """从TXT文件读取股票代码"""
        print("\n从TXT文件读取股票代码")
        print("-" * 30)
        print("文件格式说明:")
        print("- 每行一个股票代码")
        print("- 支持注释行（以#开头）")
        print("- 示例文件内容:")
        print("  # 这是注释行")
        print("  600030")
        print("  002261")
        print()
        
        while True:
            filename = input("请输入TXT文件名 (例如: stock_codes.txt): ").strip()
            if not filename:
                print("❌ 请输入文件名")
                continue
            
            # 如果用户没有输入扩展名，自动添加.txt
            if not filename.endswith('.txt'):
                filename += '.txt'
            
            try:
                with open(filename, 'r', encoding='utf-8') as f:
                    lines = f.readlines()
                
                stock_codes = []
                for line in lines:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        stock_codes.append(line)
                
                if not stock_codes:
                    print("❌ 文件中没有有效的股票代码")
                    continue
                
                # 验证股票代码格式
                valid_codes = []
                for code in stock_codes:
                    if code.isdigit() and len(code) == 6:
                        valid_codes.append(code)
                    else:
                        print(f"⚠️  股票代码 {code} 格式不正确，已跳过")
                
                if not valid_codes:
                    print("❌ 文件中没有有效的股票代码")
                    continue
                
                return valid_codes
                
            except FileNotFoundError:
                print(f"❌ 文件 {filename} 不存在")
                continue
            except Exception as e:
                print(f"❌ 读取文件失败: {e}")
                continue
    
    def download_stock_data(self, stock_code, days=100, max_retries=5):
        """下载股票数据"""
        import time
        import requests
        from requests.adapters import HTTPAdapter
        from urllib3.util.retry import Retry
        
        # 配置重试策略
        session = requests.Session()
        retry_strategy = Retry(
            total=2,
            backoff_factor=2,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        
        for attempt in range(max_retries):
            try:
                if attempt > 0:
                    print(f"正在重试下载股票 {stock_code} 的数据... (第 {attempt + 1} 次)")
                    # 递增等待时间，并添加随机抖动
                    import random
                    wait_time = 8 * attempt + random.uniform(1, 3)
                    print(f"⏳ 等待 {wait_time:.1f} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    print(f"正在下载股票 {stock_code} 的数据...")
                
                # 计算日期范围（最近100个交易日）
                end_date = datetime.now()
                start_date = end_date - timedelta(days=days*2)  # 多取一些天数确保有足够的交易日
                
                print(f"   请求日期范围: {start_date.strftime('%Y-%m-%d')} 至 {end_date.strftime('%Y-%m-%d')}")
                
                # 添加请求前的短暂延迟，避免请求过于频繁
                if attempt > 0:
                    time.sleep(2)
                
                # 使用akshare下载数据
                data = ak.stock_zh_a_hist(
                    symbol=stock_code,
                    period="daily",
                    start_date=start_date.strftime('%Y%m%d'),
                    end_date=end_date.strftime('%Y%m%d'),
                    adjust="qfq"  # 前复权
                )
                
                if data.empty:
                    print(f"❌ 股票 {stock_code}: 未找到数据")
                    return None
                
                # 重命名列以匹配Kronos格式
                data = data.rename(columns={
                    '日期': 'timestamps',
                    '开盘': 'open',
                    '收盘': 'close',
                    '最高': 'high',
                    '最低': 'low',
                    '成交量': 'volume',
                    '成交额': 'amount'
                })
                
                # 设置日期为索引
                data['timestamps'] = pd.to_datetime(data['timestamps'])
                data = data.set_index('timestamps')
                
                # 只取最近100个交易日
                if len(data) > days:
                    data = data.tail(days)
                
                print(f"✅ 股票 {stock_code}: 成功下载 {len(data)} 条记录")
                print(f"   数据范围: {data.index[0].strftime('%Y-%m-%d')} 至 {data.index[-1].strftime('%Y-%m-%d')}")
                
                return data
                
            except Exception as e:
                error_msg = str(e)
                print(f"❌ 股票 {stock_code}: 下载失败 (第 {attempt + 1} 次) - {error_msg}")
                
                # 分析错误类型
                if "Connection reset by peer" in error_msg:
                    print("   🔍 分析: 连接被服务器重置，可能是请求过于频繁")
                elif "timeout" in error_msg.lower():
                    print("   🔍 分析: 请求超时，网络可能较慢")
                elif "Connection aborted" in error_msg:
                    print("   🔍 分析: 连接被中断，可能是网络不稳定")
                
                if attempt == max_retries - 1:
                    print(f"❌ 股票 {stock_code}: 经过 {max_retries} 次尝试后仍然失败")
                    print("💡 建议:")
                    print("   1. 检查网络连接是否稳定")
                    print("   2. 稍后重试（服务器可能负载较高）")
                    print("   3. 确认股票代码是否正确")
                    print("   4. 尝试使用其他网络环境")
                    return None
                else:
                    # 更长的等待时间
                    wait_time = 8 * (attempt + 1)
                    print(f"⏳ 等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
        
        return None
    
    def prepare_prediction_data(self, data, lookback_days=100, pred_days=30):
        """准备预测数据"""
        try:
            # 确保数据长度足够
            if len(data) < lookback_days:
                print(f"⚠️  数据长度不足，需要 {lookback_days} 天，实际只有 {len(data)} 天")
                lookback_days = len(data)
            
            # 准备历史数据
            x_df = data.tail(lookback_days)[['open', 'high', 'low', 'close', 'volume', 'amount']].copy()
            x_timestamp = data.tail(lookback_days).index
            
            # 生成未来预测时间戳（交易日，排除周末和节假日）
            last_date = x_timestamp[-1]
            future_trading_days = self.generate_trading_days(last_date + timedelta(days=1), pred_days)
            
            y_timestamp = pd.Series(future_trading_days)
            
            # 确保时间戳是Series格式
            x_timestamp = pd.Series(x_timestamp)
            
            print(f"📅 预测期间: {future_trading_days[0].strftime('%Y-%m-%d')} 至 {future_trading_days[-1].strftime('%Y-%m-%d')}")
            print(f"📅 预测天数: {len(future_trading_days)} 个交易日")
            
            return x_df, x_timestamp, y_timestamp
            
        except Exception as e:
            print(f"❌ 数据准备失败: {str(e)}")
            return None, None, None
    
    def make_prediction(self, x_df, x_timestamp, y_timestamp, pred_len=30):
        """进行预测"""
        try:
            print("正在进行预测...")
            
            pred_df = self.predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=1.0,
                top_p=0.9,
                sample_count=1,
                verbose=True
            )
            
            print("✅ 预测完成!")
            return pred_df
            
        except Exception as e:
            print(f"❌ 预测失败: {str(e)}")
            return None
    
    def plot_prediction(self, stock_code, historical_data, pred_data, x_timestamp, y_timestamp):
        """绘制预测结果"""
        try:
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10), sharex=True)
            
            # 创建连续的时间轴索引
            x_indices = range(len(x_timestamp))
            y_indices = range(len(x_timestamp), len(x_timestamp) + len(y_timestamp))
            
            # 绘制价格图
            ax1.plot(x_indices, historical_data['close'], label='历史价格', color='blue', linewidth=2)
            ax1.plot(y_indices, pred_data['close'], label='预测价格', color='red', linewidth=2, linestyle='--')
            ax1.set_ylabel('收盘价 (元)', fontsize=12)
            ax1.set_title(f'股票 {stock_code} 价格预测 (排除节假日)', fontsize=14, fontweight='bold')
            ax1.legend(fontsize=11)
            ax1.grid(True, alpha=0.3)
            
            # 绘制成交量图
            ax2.plot(x_indices, historical_data['volume'], label='历史成交量', color='blue', linewidth=2)
            ax2.plot(y_indices, pred_data['volume'], label='预测成交量', color='red', linewidth=2, linestyle='--')
            ax2.set_ylabel('成交量', fontsize=12)
            ax2.set_xlabel('交易日', fontsize=12)
            ax2.legend(fontsize=11)
            ax2.grid(True, alpha=0.3)
            
            # 设置x轴刻度
            total_days = len(x_timestamp) + len(y_timestamp)
            step = max(1, total_days // 12)  # 显示约12个标签
            tick_positions = list(range(0, total_days, step))
            
            # 创建标签：历史数据用实际日期，预测数据用预测日期
            tick_labels = []
            for pos in tick_positions:
                if pos < len(x_timestamp):
                    # 历史数据标签
                    tick_labels.append(x_timestamp.iloc[pos].strftime('%m-%d'))
                else:
                    # 预测数据标签
                    pred_pos = pos - len(x_timestamp)
                    if pred_pos < len(y_timestamp):
                        tick_labels.append(y_timestamp.iloc[pred_pos].strftime('%m-%d'))
                    else:
                        tick_labels.append('')
            
            ax2.set_xticks(tick_positions)
            ax2.set_xticklabels(tick_labels, rotation=45, ha='right')
            
            # 添加分隔线区分历史和预测数据
            split_point = len(x_timestamp) - 0.5
            ax1.axvline(x=split_point, color='gray', linestyle=':', alpha=0.7, linewidth=2)
            ax2.axvline(x=split_point, color='gray', linestyle=':', alpha=0.7, linewidth=2)
            
            # 添加文本标注
            ax1.text(0.02, 0.98, f'历史数据: {len(x_timestamp)} 个交易日', 
                    transform=ax1.transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
            ax1.text(0.02, 0.88, f'预测数据: {len(y_timestamp)} 个交易日', 
                    transform=ax1.transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.8))
            
            # 添加日期范围标注
            ax1.text(0.98, 0.02, f'历史: {x_timestamp.iloc[0].strftime("%Y-%m-%d")} 至 {x_timestamp.iloc[-1].strftime("%Y-%m-%d")}', 
                    transform=ax1.transAxes, verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8), fontsize=9)
            ax1.text(0.98, 0.12, f'预测: {y_timestamp.iloc[0].strftime("%Y-%m-%d")} 至 {y_timestamp.iloc[-1].strftime("%Y-%m-%d")}', 
                    transform=ax1.transAxes, verticalalignment='bottom', horizontalalignment='right',
                    bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8), fontsize=9)
            
            plt.tight_layout()
            
            # 保存图片
            output_dir = "prediction_results"
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            
            timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.png"
            plt.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"📊 预测图表已保存: {filename}")
            
            plt.show()
            
        except Exception as e:
            print(f"❌ 绘图失败: {str(e)}")
            import traceback
            traceback.print_exc()
    
    def save_prediction_results(self, stock_code, pred_data, y_timestamp):
        """保存预测结果"""
        try:
            output_dir = "prediction_results"
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            
            # 保存为CSV
            timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.csv"
            pred_data.to_csv(csv_filename, encoding='utf-8-sig')
            print(f"💾 预测结果已保存: {csv_filename}")
            
            # 保存为JSON
            json_filename = f"{output_dir}/prediction_{stock_code}_{timestamp_str}.json"
            pred_data.to_json(json_filename, orient='index', date_format='iso')
            print(f"💾 预测结果(JSON)已保存: {json_filename}")
            
        except Exception as e:
            print(f"❌ 保存结果失败: {str(e)}")
    
    def print_prediction_summary(self, stock_code, pred_data):
        """打印预测摘要"""
        print(f"\n📈 股票 {stock_code} 预测摘要")
        print("="*50)
        print(f"预测期间: {pred_data.index[0].strftime('%Y-%m-%d')} 至 {pred_data.index[-1].strftime('%Y-%m-%d')}")
        print(f"预测天数: {len(pred_data)} 个交易日")
        print()
        
        # 价格统计
        print("价格预测:")
        print(f"  起始价格: {pred_data['close'].iloc[0]:.2f}")
        print(f"  结束价格: {pred_data['close'].iloc[-1]:.2f}")
        print(f"  最高价格: {pred_data['high'].max():.2f}")
        print(f"  最低价格: {pred_data['low'].min():.2f}")
        print(f"  价格变化: {((pred_data['close'].iloc[-1] / pred_data['close'].iloc[0]) - 1) * 100:.2f}%")
        print()
        
        # 成交量统计
        print("成交量预测:")
        print(f"  平均成交量: {pred_data['volume'].mean():.0f}")
        print(f"  最大成交量: {pred_data['volume'].max():.0f}")
        print(f"  最小成交量: {pred_data['volume'].min():.0f}")
        print()
        
        # 显示前5天和后5天的预测
        print("预测详情 (前5天):")
        print(pred_data.head().round(2))
        print()
        print("预测详情 (后5天):")
        print(pred_data.tail().round(2))
    
    def run(self, test_mode=False, test_stock_codes=None):
        """运行主程序"""
        print("🚀 交互式股票预测程序")
        print("="*60)
        print("本程序使用Kronos模型预测股票未来走势")
        print("支持A股市场，预测未来30个交易日的价格和成交量")
        print()
        
        # 加载模型
        if not self.load_models():
            return
        
        # 获取股票代码
        if test_mode and test_stock_codes:
            stock_codes = test_stock_codes
            print(f"🧪 测试模式: 使用预设股票代码 {stock_codes}")
        else:
            stock_codes = self.get_stock_codes()
            if not stock_codes:
                print("❌ 未获取到有效的股票代码")
                return
        
        print(f"\n📊 将预测以下股票: {', '.join(stock_codes)}")
        
        # 对每只股票进行预测
        for i, stock_code in enumerate(stock_codes, 1):
            print(f"\n{'='*60}")
            print(f"正在处理股票 {i}/{len(stock_codes)}: {stock_code}")
            print('='*60)
            
            # 下载数据
            data = self.download_stock_data(stock_code, days=100)
            if data is None:
                continue
            
            # 准备预测数据
            x_df, x_timestamp, y_timestamp = self.prepare_prediction_data(data, lookback_days=100, pred_days=30)
            if x_df is None:
                continue
            
            # 进行预测
            pred_data = self.make_prediction(x_df, x_timestamp, y_timestamp, pred_len=30)
            if pred_data is None:
                continue
            
            # 打印预测摘要
            self.print_prediction_summary(stock_code, pred_data)
            
            # 绘制预测图
            self.plot_prediction(stock_code, x_df, pred_data, x_timestamp, y_timestamp)
            
            # 保存预测结果
            self.save_prediction_results(stock_code, pred_data, y_timestamp)
            
            print(f"✅ 股票 {stock_code} 预测完成!")
        
        print(f"\n🎉 所有股票预测完成!")
        print("预测结果已保存到 prediction_results 目录")

def main():
    """主函数"""
    predictor = InteractiveStockPredictor()
    
    # 检查是否为测试模式
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == '--test':
        # 测试模式：使用示例股票代码
        test_codes = ['600036', '000001']  # 招商银行、平安银行
        predictor.run(test_mode=True, test_stock_codes=test_codes)
    elif len(sys.argv) > 1:
        # 命令行模式：直接指定股票代码
        stock_codes = sys.argv[1:]
        # 验证股票代码格式
        valid_codes = []
        for code in stock_codes:
            if code.isdigit() and len(code) == 6:
                valid_codes.append(code)
            else:
                print(f"⚠️  股票代码 {code} 格式不正确，已跳过")
        
        if valid_codes:
            print(f"📊 将预测以下股票: {', '.join(valid_codes)}")
            predictor.run(test_mode=True, test_stock_codes=valid_codes)
        else:
            print("❌ 没有有效的股票代码")
    else:
        # 正常交互模式
        predictor.run()

if __name__ == "__main__":
    main()
